import torch
from torch import nn
import torchaudio
from torchvision import models

class Model_zoo(nn.Module):
    """Wrap a named backbone so it maps a 2-D feature map to 2-class logits.

    The input has dims 1 and 2 swapped and is given a single channel. A 3x3
    convolution then lifts it to three channels so an image-style backbone
    can consume it. The result goes through the backbone, and the
    backbone's 1000 outputs are projected to 2 logits.

    Args:
        mt: Name of the backbone to use; currently only ``"ResNet18"``.
        device: Stored on ``self.device``; not used inside this class (the
            module is not moved to it here).
        parser1: Accepted for interface compatibility; currently unused.

    Raises:
        KeyError: If ``mt`` is not a known backbone name. The message lists
            the supported names.
    """

    def __init__(self, mt, device=None, parser1=None):
        super().__init__()
        self.mt = mt
        self.device = device
        # Factories instead of instances, so an unknown ``mt`` is rejected
        # before any backbone is constructed.
        factories = {
            # Positional args are (pretrained=False, progress=True) under
            # torchvision's legacy signature -- confirm against the installed
            # torchvision version.
            "ResNet18": lambda: models.resnet18(False, True),
            # "lcnn_spec": SPECLCNN,
        }
        if self.mt not in factories:
            raise KeyError(
                f"Unknown model type {self.mt!r}; expected one of {sorted(factories)}"
            )
        # Same name -> instance mapping as before. Construction order
        # (backbone, then conv, then fc) is unchanged, so RNG consumption and
        # state_dict key order are preserved.
        self.n_m = {name: build() for name, build in factories.items()}
        self.model = self.n_m[self.mt]
        # 1 -> 3 channels so a backbone expecting 3-channel input accepts it.
        self.conv = nn.Conv2d(1, 3, kernel_size=3, stride=1, padding=1)
        # Maps the backbone's 1000 outputs to 2 classes.
        self.fc = nn.Linear(1000, 2)

    def forward(self, x, Freq_aug=False):
        """Return 2-class logits for a batch of 2-D features.

        Args:
            x: 3-D tensor whose dims 1 and 2 are swapped before use.
                Presumably (batch, time, n_filters) -- TODO confirm against
                the feature extractor.
            Freq_aug: Accepted but currently unused.

        Returns:
            The output of ``self.fc``, with shape (batch, 2) when the backbone
            returns (batch, 1000).
        """
        # (B, A, C) -> (B, C, A)
        x = torch.transpose(x, 1, 2)
        # Add a singleton channel dimension: (B, 1, C, A)
        x = x.unsqueeze(dim=1)
        # Lift 1 channel to 3 for the backbone: (B, 3, C, A)
        x = self.conv(x)
        x = self.model(x)
        # Project the backbone's 1000 outputs down to 2 classes.
        x = self.fc(x)
        return x

from ASVBaselineTool.DEMONEED import *
def demo():
    """Smoke-test that a spectrogram feature flows through ResNet18.

    Builds ``Model_zoo("ResNet18")`` once and runs each feature through it in
    eval mode without autograd. For each feature it prints the output shape
    followed by a success message.
    """
    # Presumably provided by the star import from ASVBaselineTool.DEMONEED.
    USE_FEATURE = SPECfeature40000
    # Check that data passes through the resnet network end to end.
    net = Model_zoo("ResNet18")
    # Inference-only check: eval mode uses BatchNorm running statistics,
    # which avoids training-mode errors on tiny batches. no_grad skips
    # building an autograd graph that is never used.
    net.eval()
    with torch.no_grad():
        for feature in (USE_FEATURE,):
            # forward() swaps dims 1 and 2, so it expects a batched 3-D
            # input. An unbatched feature would first need
            # feature.unsqueeze(dim=0) -- TODO confirm the feature's shape.
            print(net(feature).shape)
            print("resnet 模型 spec 特征输出正常")

if __name__ == "__main__":
    # Run the smoke test only when this file is executed as a script.
    demo()